{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TensorRT backend: config/profile helpers, runner, tactic record/replay, engine build and save\n",
    "from polygraphy.backend.trt import CreateConfig, Profile, TrtRunner, TacticRecorder, TacticReplayer, TacticReplayData\n",
    "from polygraphy.backend.trt import network_from_onnx_path, engine_from_network, create_config, save_engine\n",
    "\n",
    "# ONNX Runtime and ONNX model helpers\n",
    "from polygraphy.backend.onnxrt import OnnxrtRunner, session_from_onnx\n",
    "from polygraphy.backend.onnx import modify_outputs, onnx_from_path, save_onnx\n",
    "from polygraphy.backend.onnx.util import str_from_onnx, all_tensor_names\n",
    "\n",
    "# Comparison utilities, input metadata and JSON loading\n",
    "from polygraphy.comparator import Comparator, CompareFunc, DataLoader\n",
    "from polygraphy.common import TensorMetadata\n",
    "from polygraphy.json import load_json\n",
    "\n",
    "\n",
    "# Path of the ONNX model that every later cell loads, parses or runs\n",
    "ONNX_MODEL = \"./models/crnn.onnx\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the ONNX model from disk and print a text description of it.\n",
    "# mode=\"attrs\" is passed straight to str_from_onnx to choose the level of detail.\n",
    "onnx_proto = onnx_from_path(ONNX_MODEL)\n",
    "onnx_str = str_from_onnx(onnx_proto, mode=\"attrs\")\n",
    "print(onnx_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Modify the onnx model's outputs\n",
    "\n",
    "# Tensors to expose as graph outputs (names as they appear in the model printed above).\n",
    "# constants.MARK_ALL can be used instead to mark all tensors as outputs.\n",
    "EXTRA_OUTPUTS = ['646', '351']\n",
    "\n",
    "# copy=True: per the Polygraphy loader docs the model is otherwise modified in place,\n",
    "# which would silently change `onnx_proto` and make this cell non-idempotent.\n",
    "modified = modify_outputs(onnx_proto, outputs=EXTRA_OUTPUTS, copy=True)\n",
    "onnx_modified = str_from_onnx(modified, mode='attrs')\n",
    "print(onnx_modified)\n",
    "\n",
    "# Trailing ';' suppresses the cell output without overwriting IPython's `_`\n",
    "save_onnx(modified, \"modified.onnx\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the ONNX file into a TensorRT network.\n",
    "# The call returns three objects, unpacked here as builder, network and parser.\n",
    "(builder, network, parser) = network_from_onnx_path(ONNX_MODEL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build an FP16 engine from the parsed network, recording tactics to trt_tactics.json\n",
    "\n",
    "recorder = TacticRecorder(\"trt_tactics.json\")\n",
    "\n",
    "# One profile for \"input_0\"; only the leading dimension differs between min / opt / max\n",
    "optimization_profiles = [\n",
    "    Profile().add(\n",
    "        \"input_0\",\n",
    "        min=(1, 3, 32, 1024),\n",
    "        opt=(8, 3, 32, 1024),\n",
    "        max=(16, 3, 32, 1024),\n",
    "    )\n",
    "]\n",
    "\n",
    "# Every CreateConfig option is spelled out; None leaves that option unset\n",
    "trt_config = CreateConfig(\n",
    "    max_workspace_size=4096000000,\n",
    "    tf32=None,\n",
    "    fp16=True,\n",
    "    int8=None,\n",
    "    profiles=optimization_profiles,\n",
    "    calibrator=None,\n",
    "    strict_types=None,\n",
    "    load_timing_cache=None,\n",
    "    algorithm_selector=recorder,\n",
    "    sparse_weights=None,\n",
    "    tactic_sources=[0, 1, 2],\n",
    "    restricted=None,\n",
    ")\n",
    "\n",
    "# The builder and network come from the parsing cell above\n",
    "engine = engine_from_network((builder, network), config=trt_config)\n",
    "save_engine(engine, \"crnn_fp16.plan\")\n",
    "\n",
    "# To ensure the reproducibility of the engine build, the recorded tactics can be\n",
    "# replayed by passing a TacticReplayer as the algorithm selector instead:\n",
    "\n",
    "# replayer = TacticReplayer(\"trt_tactics.json\")\n",
    "# optimization_profiles = [Profile().add(\"input_0\", min=(1,3,32,1024), opt=(8,3,32,1024), max=(16,3,32,1024))]\n",
    "# trt_config = CreateConfig(max_workspace_size=2048000000,\n",
    "#                          tf32=None,\n",
    "#                          fp16=True,\n",
    "#                          int8=None,\n",
    "#                          profiles=optimization_profiles,\n",
    "#                          calibrator=None,\n",
    "#                          strict_types=None, \n",
    "#                          load_timing_cache=None,\n",
    "#                          algorithm_selector=replayer,\n",
    "#                          sparse_weights=None,\n",
    "#                          tactic_sources=[0,1,2],\n",
    "#                          restricted=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run inference with comparator\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Describe the single input \"input_0\" (float32, 16x3x32x1024) for the data loader\n",
    "meta = TensorMetadata()\n",
    "meta.add(\"input_0\", np.float32, [16, 3, 32, 1024])\n",
    "loader = DataLoader(input_metadata=meta)\n",
    "\n",
    "onnx_session = session_from_onnx(ONNX_MODEL)\n",
    "\n",
    "# Runner order matters later: index 0 is TensorRT, index 1 is ONNX Runtime\n",
    "runners = [\n",
    "    TrtRunner(engine),\n",
    "    OnnxrtRunner(onnx_session),\n",
    "]\n",
    "\n",
    "# save_inputs_path stores the inputs in inputs.json; a later cell reloads that file\n",
    "run_results = Comparator.run(\n",
    "    runners,\n",
    "    data_loader=loader,\n",
    "    save_inputs_path=\"inputs.json\",\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare results across runners\n",
    "\n",
    "# NOTE(review): rtol=10 looks very permissive next to atol=0.15 - confirm it is intended\n",
    "compare = CompareFunc.basic_compare_func(\n",
    "    check_shapes=True,\n",
    "    rtol=10,\n",
    "    atol=0.15,\n",
    "    fail_fast=None,\n",
    "    check_error_stat=\"max\",\n",
    ")\n",
    "\n",
    "accuracy_result = Comparator.compare_accuracy(\n",
    "    run_results,\n",
    "    compare_func=compare,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the engine with the tail of the network pinned to FP32\n",
    "\n",
    "import tensorrt as trt\n",
    "\n",
    "# Layer types that are skipped when forcing FP32\n",
    "EXCLUDE_LAYERS = [trt.LayerType.SHAPE,trt.LayerType.CONSTANT,trt.LayerType.CONCATENATION,trt.LayerType.GATHER,trt.LayerType.SLICE,trt.LayerType.SHUFFLE]\n",
    "\n",
    "# Index of the first layer that is forced to FP32; layers below this index are left alone\n",
    "FIRST_FP32_LAYER = 256\n",
    "\n",
    "# Parse a fresh network for this build. Editing the shared `network` in place would\n",
    "# change what the earlier FP16 build cell produces if that cell is run again.\n",
    "builder_v2, network_v2, parser_v2 = network_from_onnx_path(ONNX_MODEL)\n",
    "\n",
    "num_layer = network_v2.num_layers\n",
    "# Walk from the last layer down to FIRST_FP32_LAYER (inclusive)\n",
    "for i in range(num_layer - 1, FIRST_FP32_LAYER - 1, -1):\n",
    "    layer = network_v2.get_layer(i)\n",
    "    if layer.type not in EXCLUDE_LAYERS:\n",
    "        print(\"setting layer_{} to fp32\".format(i))\n",
    "        layer.reset_precision()\n",
    "        layer.precision = trt.float32\n",
    "        # Only output 0 of each layer is retyped\n",
    "        layer.set_output_type(0,trt.float32)\n",
    "\n",
    "# Same options as the first build, but with strict_types=True and no tactic recorder\n",
    "trt_config = CreateConfig(max_workspace_size=4096000000,\n",
    "                          tf32=None,\n",
    "                          fp16=True,\n",
    "                          int8=None,\n",
    "                          profiles=optimization_profiles,\n",
    "                          calibrator=None,\n",
    "                          strict_types=True, \n",
    "                          load_timing_cache=None,\n",
    "                          algorithm_selector=None,\n",
    "                          sparse_weights=None,\n",
    "                          tactic_sources=[0,1,2],\n",
    "                          restricted=None)\n",
    "\n",
    "engine_v2 = engine_from_network((builder_v2,network_v2), config=trt_config)\n",
    "save_engine(engine_v2,\"crnn_fp16_v2.plan\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the second engine against ONNX Runtime using the inputs stored in inputs.json\n",
    "\n",
    "# Collect the entries of every listed JSON file into one list\n",
    "loader = []\n",
    "for input_data_path in [\"inputs.json\"]:\n",
    "    loader += load_json(input_data_path, description=\"input data\")\n",
    "\n",
    "runners_v2 = [\n",
    "    TrtRunner(engine_v2),\n",
    "    OnnxrtRunner(onnx_session),\n",
    "]\n",
    "\n",
    "run_results_v2 = Comparator.run(runners_v2, data_loader=loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the second run using the `compare` function defined earlier\n",
    "accuracy_result = Comparator.compare_accuracy(\n",
    "    run_results_v2,\n",
    "    compare_func=compare,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validate NaN and inf\n",
    "Comparator.validate(\n",
    "    run_results_v2,\n",
    "    check_inf=True,\n",
    "    check_nan=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Direct access to the run_result of the TrtRunner (first entry of run_results_v2)\n",
    "trt_iterations = run_results_v2[0][1]\n",
    "# Display \"output_0\" from the first element\n",
    "trt_iterations[0][\"output_0\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "3777ac85f82c8401b3eebaa61eadb719a036e8abe3fddb34650d89b5e6ca4402"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
